
Attention Mechanisms and Their Variants

Introduction to Attention Mechanisms

  • Attention is a powerful mechanism that enables models to focus on specific parts of the input sequence, enhancing their ability to understand and process complex patterns.
  • Over time, researchers have developed several variants of attention mechanisms to further improve model performance and efficiency.
  • These include
    • Multi-Query Attention (MQA),
    • Grouped Query Attention (GQA),
    • Sliding Window Attention,
    • Flash Attention,
    • Paged Attention, and more.

Each of these variants brings unique advantages and is suited to different types of tasks and challenges in the field of deep learning. In the following sections, we will delve into the details of these attention variants, exploring their characteristics, applications, and impact on model performance.

MHA: Multi-Head Attention

  • Before Multi-Head Attention (MHA), attention was computed with a single head.
  • MHA runs several heads in parallel, each with its own learned projections of Q, K, and V, so each head can focus on different parts of the input sequence.
    • This allows the model to learn different types of relationships between tokens.
    • For example, one head might focus on the subject of a sentence, while another head focuses on the object.
    • This can help the model to better understand the structure of the input sequence and make more accurate predictions.
  • MHA is a key component of the transformer architecture and has been used in many state-of-the-art natural language processing models, such as BERT and GPT-3.
  • Each head computes scaled dot-product attention over its own projections of the input.
  • Equation: \(\mathrm{Attention}(Q,K,V) = \mathrm{softmax}\left(\frac{QK^T}{\sqrt{d_k}}\right)V\), where \(d_k\) is the dimension of the key vectors.
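To make the equation concrete, here is a minimal pure-Python sketch of scaled dot-product attention and a multi-head wrapper. It assumes plain nested lists as matrices, random weights, and the usual "project per head, attend, concatenate, project with W_O" layout; the helper names (`matmul`, `multi_head_attention`, etc.) are illustrative, not taken from any library.

```python
import math
import random

def softmax(xs):
    # Subtract the max for numerical stability before exponentiating.
    m = max(xs)
    exps = [math.exp(x - m) for x in xs]
    total = sum(exps)
    return [e / total for e in exps]

def matmul(a, b):
    # Plain-list matrix product: (n x k) @ (k x m) -> (n x m).
    return [[sum(a[i][t] * b[t][j] for t in range(len(b)))
             for j in range(len(b[0]))] for i in range(len(a))]

def transpose(a):
    return [list(row) for row in zip(*a)]

def attention(q, k, v):
    """Scaled dot-product attention: softmax(Q K^T / sqrt(d_k)) V."""
    d_k = len(k[0])
    scores = matmul(q, transpose(k))
    weights = [softmax([s / math.sqrt(d_k) for s in row]) for row in scores]
    return matmul(weights, v)

def multi_head_attention(x, w_q, w_k, w_v, w_o):
    """x: n x d_model. w_q, w_k, w_v: one (d_model x d_head) matrix per head.
    w_o: (h * d_head) x d_model output projection."""
    heads = [attention(matmul(x, wq), matmul(x, wk), matmul(x, wv))
             for wq, wk, wv in zip(w_q, w_k, w_v)]
    # Concatenate the head outputs feature-wise for each token, then project.
    concat = [sum((head[i] for head in heads), []) for i in range(len(x))]
    return matmul(concat, w_o)

def rand_matrix(rows, cols):
    return [[random.gauss(0.0, 1.0) for _ in range(cols)] for _ in range(rows)]

random.seed(0)
n, d_model, h = 4, 8, 2
d_head = d_model // h
x = rand_matrix(n, d_model)
w_q = [rand_matrix(d_model, d_head) for _ in range(h)]
w_k = [rand_matrix(d_model, d_head) for _ in range(h)]
w_v = [rand_matrix(d_model, d_head) for _ in range(h)]
w_o = rand_matrix(h * d_head, d_model)
out = multi_head_attention(x, w_q, w_k, w_v, w_o)
print(len(out), len(out[0]))  # 4 8
```

Each head sees the same tokens but through different learned projections, which is what lets different heads specialize.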

Note

MHA has some limitations

For example, it has quadratic compute and memory complexity with respect to the input sequence length: doubling the sequence length roughly quadruples the cost of the attention computation. Early transformer models worked with short sequences (e.g., BERT is limited to 512 tokens), but sequences of 5k, 8k, or 32k tokens are now common, so inference (e.g., in RAG applications) becomes much more expensive. Therefore, innovation in attention mechanisms is necessary.

For the standard self-attention mechanism, the time complexity is O(n^2·d), where:

  • n represents the number of input tokens (or sequence length).
  • d denotes the dimensionality of the vector representations.

This quadratic complexity arises due to the pairwise token operations required by self-attention.

Image: Attention complexity table from the Transformer paper. Reference: Attention Is All You Need
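A quick back-of-the-envelope calculation shows what the O(n^2·d) term means in practice. This sketch assumes a per-head dimension of 64 and fp16 (2-byte) scores purely for illustration, and counts only the two n x n matrix products of a single head.

```python
def attention_multiply_adds(n, d):
    """Approximate multiply-adds for one head: Q K^T (n*n*d) plus weights @ V (n*n*d)."""
    return 2 * n * n * d

def score_matrix_bytes(n, bytes_per_element=2):
    """Memory for one full n x n attention-score matrix (default 2 bytes, i.e. fp16)."""
    return n * n * bytes_per_element

d = 64  # assumed per-head dimension, for illustration only
for n in (512, 8192, 32768):
    print(n, attention_multiply_adds(n, d), score_matrix_bytes(n))

ratio = attention_multiply_adds(32768, d) // attention_multiply_adds(512, d)
print(ratio)  # 4096: 64x more tokens -> 64^2 = 4096x more work
```

Going from 512 to 32,768 tokens multiplies the sequence length by 64 but the attention work by 4,096, and a single fp16 score matrix for one head at 32,768 tokens already occupies 2 GiB.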

MQA: Multi-Query Attention

  • Implemented in Falcon 7B
  • Multiple query heads share a single key/value head, so the KV cache is much smaller.
  • 12x faster decoding during inference.
  • Less computation time than MHA.
  • Reduced memory usage: batch size can be increased.
  • Model quality (Perplexity): minor degradation compared to MHA.
  • The model must be trained with MQA; it can't be applied to an existing pre-trained MHA model without further training.
  • Tensor parallelism: with a single K,V head there is nothing to split, so K and V must be replicated on every node.
  • Paper: Fast Transformer Decoding: One Write-Head is All You Need.
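The sketch below illustrates the two MQA points above: every query head reads the same cached K and V during a decoding step, and the KV cache shrinks by the number of heads. The model configuration (32 layers, 32 heads, head dimension 128, 4096-token context, fp16) is a hypothetical 7B-class example, and the function names are illustrative.

```python
import math

def softmax(xs):
    m = max(xs)
    exps = [math.exp(x - m) for x in xs]
    total = sum(exps)
    return [e / total for e in exps]

def mqa_decode_step(queries, k_cache, v_cache):
    """One decoding step of multi-query attention.
    queries: one query vector (length d_head) per query head.
    k_cache, v_cache: a single shared list of cached key/value vectors (t x d_head),
    read by every query head instead of one cache per head."""
    d_head = len(k_cache[0])
    outputs = []
    for q in queries:
        scores = [sum(qi * ki for qi, ki in zip(q, k)) / math.sqrt(d_head) for k in k_cache]
        weights = softmax(scores)
        outputs.append([sum(w * v[j] for w, v in zip(weights, v_cache)) for j in range(d_head)])
    return outputs

def kv_cache_bytes(num_layers, num_kv_heads, head_dim, seq_len, batch=1, bytes_per_element=2):
    # Factor 2 accounts for storing both K and V.
    return 2 * num_layers * num_kv_heads * head_dim * seq_len * batch * bytes_per_element

# Hypothetical 7B-class config: 32 layers, 32 query heads, head_dim 128, 4096 tokens, fp16.
mha = kv_cache_bytes(32, num_kv_heads=32, head_dim=128, seq_len=4096)
mqa = kv_cache_bytes(32, num_kv_heads=1, head_dim=128, seq_len=4096)
print(mha // 2**20, mqa // 2**20)  # 2048 64 (MiB)
```

Under these assumptions the per-sequence cache drops from 2 GiB to 64 MiB, which is the memory that can instead go to larger batches.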

GQA: Grouped Query Attention

  • Implemented in Llama 2 70B and Mistral 7B
  • An existing MHA model can be uptrained to GQA with a modest amount of additional pre-training (not fine-tuning); full retraining is not required.
  • Better fit for tensor parallelism: with multiple K,V heads, they can be split across GPUs, making better use of the hardware.
  • A trade-off between MHA and MQA in quality and speed.
  • Query heads are divided into subgroups; each subgroup shares a common K,V head.
  • If the number of groups = 1,
    • it will become MQA.
    • All queries are in one group. Common KV for all queries.
  • If the number of groups = number of heads,
    • it's MHA.
    • Slower inference and more computation, but higher model quality.
  • Paper: GQA: Training Generalized Multi-Query Transformer.
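The grouping rule can be written down in a few lines. This is a sketch assuming consecutive query heads share a group and the head count is divisible by the group count. The second function illustrates the checkpoint conversion the GQA paper describes (averaging the K or V projection weights of the heads in each group before uptraining), simplified so that each head's weights are a flat list of floats; both names are hypothetical.

```python
def kv_head_for_query(query_head, num_heads, num_groups):
    """Index of the shared K,V head that a query head reads in GQA.
    Assumes num_heads is divisible by num_groups; consecutive query heads share a group."""
    if num_heads % num_groups != 0:
        raise ValueError("num_heads must be divisible by num_groups")
    return query_head // (num_heads // num_groups)

def mean_pool_kv_heads(kv_heads, num_groups):
    """Collapse per-head K (or V) projection weights into num_groups heads by averaging
    the heads inside each group. Each head's weights are a flat list of floats here."""
    group_size = len(kv_heads) // num_groups
    pooled = []
    for g in range(num_groups):
        group = kv_heads[g * group_size:(g + 1) * group_size]
        pooled.append([sum(vals) / len(group) for vals in zip(*group)])
    return pooled

num_heads = 8
for num_groups in (1, 2, 8):
    print(num_groups, [kv_head_for_query(i, num_heads, num_groups) for i in range(num_heads)])
# 1 [0, 0, 0, 0, 0, 0, 0, 0]   -> MQA
# 2 [0, 0, 0, 0, 1, 1, 1, 1]   -> GQA
# 8 [0, 1, 2, 3, 4, 5, 6, 7]   -> MHA
```

The two extremes in the list above fall out directly: one group maps every query head to the same K,V head, and one group per head recovers plain MHA.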

Sliding Window Attention

  • Implemented in Mistral

  • Motivation:

    • In attention-based models (such as Transformers), full (non-sparse) self-attention has O(n^2) time and memory complexity, where n is the input sequence length.
    • This complexity makes it inefficient to scale to long inputs.

Image: Sliding Window Attention in Mistral

  • What Is Sliding Window Attention?
    • Sliding Window Attention (SWA) is a technique used in transformer models to limit the attention span of each token.
    • It focuses on a fixed-size window around each token, capturing local context efficiently.
  • How It Works:
    • Given a fixed window size w, each token attends to its neighboring tokens within the window.
    • The computation complexity of this pattern is O(n × w), which scales linearly with input sequence length.
    • To maintain efficiency, w should be small compared to n.
  • Receptive Field:

    • Stacking multiple layers of such windowed attention results in a large receptive field.
    • With enough layers, the top layers have access to all input locations, incorporating information across the entire input.
  • Balancing Efficiency and Representation Capacity:

    • Depending on the application, different values of w can be used for each layer.
    • This balances between efficiency and model representation capacity.
  • Applications:
    • SWA has applications in various tasks, including natural language processing and computer vision.
    • In short, SWA lets attention-based models process long sequences efficiently while keeping the relevant local context.
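The cost and receptive-field claims above can be checked with a small simulation. This sketch assumes a causal window in which token i attends to itself and the previous w - 1 tokens, as in decoder models such as Mistral; an encoder-style variant would instead center the window on each token. The function names are illustrative.

```python
def sliding_window_mask(n, w):
    """Causal sliding-window mask: token i may attend to token j iff i - w < j <= i,
    i.e. itself and the previous w - 1 tokens (at most w tokens per row)."""
    return [[i - w < j <= i for j in range(n)] for i in range(n)]

def attended_pairs(mask):
    # Number of (query, key) pairs actually scored: O(n * w) instead of O(n^2).
    return sum(sum(row) for row in mask)

def reachable_positions(n, w, num_layers, pos):
    """Positions whose information can reach `pos` after stacking num_layers windowed layers."""
    mask = sliding_window_mask(n, w)
    reach = {pos}
    for _ in range(num_layers):
        reach = {j for i in reach for j in range(n) if mask[i][j]}
    return reach

n, w = 8, 3
print(attended_pairs(sliding_window_mask(n, w)))  # 21, versus 36 for full causal attention
print(len(reachable_positions(20, 3, 4, 19)))  # 9 = 4 * (3 - 1) + 1
```

Under this convention each layer extends the reach by w - 1 positions, so k layers cover k(w - 1) + 1 tokens. With a 4096-token window and 32 layers, the figures reported for Mistral 7B, that is on the order of 131K tokens.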

Other:

  • Flash Attention
  • Paged Attention (most recent, as of 9/23)

Reference:

YouTube Video